
# Print numbers with 4 significant digits.
options(digits = 4)
# Load pacman, then use it to attach every package the script relies on.
library(pacman)
pacman::p_load(
  data.table, tidyverse, lfe, rio, knitr,
  broom, foreach, magrittr, mvtnorm, tictoc
)
library(LalRUtils) # install_github("apoorval/LalRUtils")
# Fix the RNG seed so the simulated draws below are reproducible.
set.seed(42)

# Use the minimal ggplot2 theme for every figure produced by this script.
theme_set(theme_minimal())
# Plot size for notebook output (presumably consumed by the Jupyter `repr`
# machinery; has no visible use elsewhere in this script -- confirm).
options(repr.plot.width = 12, repr.plot.height = 9)

# %%
# Simulate an IV data set that contains a "zero first stage" (ZFS) subgroup.
#
# Per observation:
#   D_star = α + π * Z + e,   D = 1{D_star > 0},   Y = γ * Z + β * D + e
# The same shock `e` enters both the selection equation and the outcome.
# For the last `share_zfs` fraction of rows π is set to 0, so Z does not
# enter D_star for those rows.
#
# Arguments:
#   γ          direct effect of Z on Y (0 means the exclusion restriction holds).
#   n          number of observations.
#   share_zfs  fraction of rows (taken from the end of the sample) placed in
#              the ZFS subgroup; must lie in [0, 1].
#   α_0        mean of the intercept draw α ~ N(α_0, 1).
#   π_0        centre of the first-stage slope draw π ~ U(π_0 - .5, π_0 + .5).
#   β_0, σ_β   mean and sd of the treatment-effect draw β ~ N(β_0, σ_β).
#   sed        optional RNG seed. It is applied only when the caller passes it
#              explicitly, so calls that rely on the default keep drawing from
#              the ambient RNG stream exactly as before.
#
# Returns a data.table with columns Z, α, π, β, e, zfs_pop (1 for ZFS rows,
# NA otherwise), D_star, D and Y.
dgp_zfs <- function(
    γ = 0, n = 1e4, share_zfs = 0.2,
    α_0 = -1, π_0 = 2, β_0 = 1, σ_β = 0.25, sed = 42) {
  stopifnot(share_zfs >= 0, share_zfs <= 1)
  if (!missing(sed)) set.seed(sed)

  # Draw order is kept fixed so results under a given seed are stable.
  Z <- rbinom(n, 1, 0.5)
  α <- rnorm(n, α_0, 1)
  π <- runif(n, π_0 - .5, π_0 + .5)
  e <- rnorm(n, 0, 1)

  β <- rnorm(n, β_0, σ_β)
  d <- data.table(Z, α, π, β, e)

  # Zero-first-stage population: the last share_zfs * n rows. A logical mask
  # (rather than a `from:n` range) avoids including the boundary row twice
  # and also handles share_zfs = 0 without marking any row.
  n_regular <- round((1 - share_zfs) * n)
  is_zfs <- seq_len(n) > n_regular
  d[, zfs_pop := ifelse(is_zfs, 1, NA_real_)]
  d[which(is_zfs), π := 0]

  # Generate latent index, binary treatment and outcome.
  d[, D_star := α + π * Z + e]
  d[, D := as.numeric(D_star > 0)]
  d[, Y := γ * Z + β * D + e]
  d
}

# Baseline draw with the generator's defaults (γ = 0: no direct Z -> Y effect).
d <- dgp_zfs()
# %%
# Both regressions below use only the zero-first-stage rows, i.e. the rows
# where zfs_pop is non-missing.
# First stage in that subgroup: D on Z.
zfs_fs <- felm(D ~ Z, data = d[!is.na(zfs_pop)])
# Reduced form in that subgroup: Y on Z (no direct effect in this draw, so
# this slope is expected to be near zero as well).
zfs_rf <- felm(Y ~ Z, data = d[!is.na(zfs_pop)])


## DGP w Exclusion Restriction violation
# Redraw with a direct effect of Z on Y (γ = 0.2) and mean treatment effect 2.
d <- dgp_zfs(γ = 0.2, β_0 = 2)
# Full-sample reduced form (Y on Z) and first stage (D on Z). The reduced
# form previously regressed Y on D, which is not a reduced form.
rf <- felm(Y ~ Z, d)
fs <- felm(D ~ Z, d)
# First stage within the zero-first-stage subgroup (zfs_pop non-missing).
zfs_fs <- felm(D ~ Z, d[!is.na(zfs_pop)])
# Reduced form within the same subgroup. With γ = 0.2 this slope is no longer
# expected to be zero; it is what the LTZ prior mean is taken from below.
zfs_rf <- felm(Y ~ Z, d[!is.na(zfs_pop)])
# 2SLS of Y on D, instrumenting D with Z; keep coefficients and covariance.
tsls <- felm(Y ~ 1 | 0 | (D ~ Z), d)
b2sls <- tsls$coefficients
V2sls <- tsls$vcv

# Adjust a 2SLS estimate for a non-zero direct effect of the instrument
# (the "LTZ" estimator used in the figure below; the formula appears to
# follow the local-to-zero approach -- confirm against the intended source).
#
# Arguments:
#   μ       prior mean of the direct effect of Z on the outcome.
#   σ       value placed on the diagonal of the prior covariance for the Z
#           entry (it is used as given, not squared).
#   b_tsls  2SLS coefficients (intercept, D); defaults to the global b2sls.
#   V_tsls  2SLS covariance matrix; defaults to the global V2sls.
#   df      data with columns Z and D; defaults to the global d.
#
# Returns a length-4 numeric vector: adjusted estimate for D, its standard
# error, and the 2.5% / 97.5% normal quantiles around that estimate.
ltz <- function(μ, σ, b_tsls = b2sls, V_tsls = V2sls, df = d) {
  # Prior on the direct effects of (intercept, Z): the intercept entry is
  # pinned at zero with zero variance.
  prior_mean <- c(0, μ)
  prior_cov <- diag(c(0, σ))

  # Instrument and regressor matrices, each with a leading column of ones.
  inst <- cbind(1, df$Z)
  regs <- cbind(1, df$D)

  # A = (X'Z (Z'Z)^-1 Z'X)^-1 X'Z
  xtz <- t(regs) %*% inst
  xpzx <- xtz %*% solve(t(inst) %*% inst) %*% t(inst) %*% regs
  A <- solve(xpzx) %*% xtz

  # Shift the point estimate by the prior mean and widen the variance by the
  # prior covariance, both mapped through A.
  shifted <- b_tsls - A %*% prior_mean
  widened <- V_tsls + A %*% prior_cov %*% t(A)

  est <- shifted[2]
  se <- sqrt(diag(widened)[2])
  c(est, se, qnorm(c(0.025, 0.975), est, se))
}

# %%
# LTZ estimate on the full sample: the prior is centred on the reduced-form
# Z coefficient from the zero-first-stage subgroup, with 0.05 as the prior
# covariance entry. Wrapped in parentheses so the result is also printed.
(ltz_out <- ltz(zfs_rf$coefficients[2], 0.05))
# 2SLS estimate, standard error and confidence limits for the instrumented D.
tsls_out <- broom::tidy(tsls, conf.int = TRUE) %>%
  filter(term == "`D(fit)`") %>%
  select(estimate, std.error, conf.low:conf.high) %>%
  as.numeric()
# Stack both methods; the row names (tsls_out / ltz_out) become the first
# column.
output <- rbind(tsls_out, ltz_out) %>%
  as.data.frame() %>%
  setDT(keep.rownames = TRUE)
# Both rows are ordered (estimate, SE, lower limit, upper limit), so the
# lower-bound label has to precede the upper-bound label.
colnames(output) <- c("Method", "Est", "SE", "Lb", "Ub")

output %>% kable()

# %%
# Sensitivity of 2SLS and LTZ to exclusion-restriction violations.
#
# For each direct-effect size in K a fresh data set is simulated, 2SLS is
# fitted, and the LTZ adjustment is applied with a prior centred on the
# reduced-form Z coefficient from the zero-first-stage subgroup.
#
# Arguments:
#   K    numeric vector of direct effects γ to simulate.
#   ...  further arguments forwarded to dgp_zfs() (e.g. β_0).
#
# Returns a numeric matrix with one row per element of K and columns
# gamma, tsls_coef, tsls_se, ltz_coef, ltz_se (zero rows for empty K).
LTZ_sens <- function(K, ...) {
  result <- matrix(NA_real_, length(K), 5)
  colnames(result) <- c("gamma", "tsls_coef", "tsls_se", "ltz_coef", "ltz_se")
  for (i in seq_along(K)) {
    d <- dgp_zfs(γ = K[i], ...)
    tsls <- felm(Y ~ 1 | 0 | (D ~ Z), d)
    b_2sls <- tsls$coefficients
    V_2sls <- tsls$vcv
    # zfs pop
    zfs_rf <- felm(Y ~ Z, d[!is.na(zfs_pop)])
    # Use ltz with the prior from the zfs pop. `df` must be passed
    # explicitly: ltz()'s default `df = d` is looked up from where ltz() was
    # defined, so it would pick up the global `d` instead of this
    # iteration's simulated data.
    ltz_out <- ltz(
      μ = zfs_rf$coefficients[2], σ = 0.01,
      b_tsls = b_2sls, V_tsls = V_2sls, df = d
    )
    result[i, ] <- c(K[i], b_2sls[2], tsls$se[2], ltz_out[1:2])
  }
  result
}
# Grid of 20 direct-effect sizes, evenly spaced over [-2, 2].
γs <- seq(from = -2, to = 2, length.out = 20)
# Run the sensitivity simulation with mean treatment effect 2 and keep the
# result as a data.table for plotting.
sens_out <- as.data.table(LTZ_sens(γs, β_0 = 2))
head(sens_out)

# %% fig A9

# Sensitivity figure: 2SLS and LTZ point estimates with +/- 1.96 SE bands
# across the simulated direct-effect sizes, against the true effect of 2.
p <- ggplot(sens_out, aes(x = gamma, y = tsls_coef)) +
  # 2SLS point estimates (y inherited from the base mapping)
  geom_point() +
  # reference line at the true effect
  geom_hline(yintercept = 2) +
  annotate("text", x = 2, y = 2.2, label = "Truth") +
  # 2SLS band
  geom_ribbon(
    aes(ymin = tsls_coef - 1.96 * tsls_se, ymax = tsls_coef + 1.96 * tsls_se),
    alpha = 0.4, fill = "cornflowerblue"
  ) +
  annotate(
    "text",
    x = 2, y = 6.5, angle = 45, label = "2SLS", colour = "blue"
  ) +
  # LTZ point estimates and band
  geom_point(aes(y = ltz_coef)) +
  geom_ribbon(
    aes(ymin = ltz_coef - 1.96 * ltz_se, ymax = ltz_coef + 1.96 * ltz_se),
    alpha = 0.4, fill = "orangered"
  ) +
  annotate("text", x = -2, y = 2.6, label = "LTZ", colour = "orangered") +
  labs(
    title = "LTZ and TSLS coefficients for Exclusion restriction violations of varying severity",
    subtitle = "True effect = 2",
    y = "coef",
    x = expression(gamma)
  )
# Display the figure.
p

# %%
# Write the sensitivity figure to disk. The output directory is created
# first (a no-op if it already exists) so that saving does not fail on a
# fresh checkout without a graphs/ folder.
dir.create("graphs", showWarnings = FALSE, recursive = TRUE)
ggsave("graphs/FigureA9.pdf", p, width = 10, height = 7)
# NOTE(review): this rebinds `ltz` to NULL, so the ltz() function is no
# longer callable after this point -- confirm this cleanup is intended.
ltz <- NULL